#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <limits.h>
//#include <pthread.h>
#include <assert.h>
//#include <stdbool.h>
#include <netinet/in.h>
#include "binary_protocol.h"
#include "general.h"
#include "mc.h"
#include "util.h"

#include <event2/event.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <binary_protocol.h>
//#include <event2/listener.h>
//#include <event2/thread.h>

/* threads: declared for multi-threaded use, but not locked anywhere in this file */
pthread_mutex_t stats_mutex;
pthread_mutex_t fd_mutex;

/* sockets */
int tcp_fd;                        /* TCP connection to the server, from clientsock() in main() */
int u_sock_fd;                     /* UDP socket; the UDP path is not implemented (main() exits) */
struct hostent *u_server;          /* UDP server lookup result; unused in this file */
struct sockaddr_in u_server_addr;  /* UDP server address; unused in this file */
extern int **server_fds;           /* defined in another translation unit */

/* env variables: set by init_default_env() and the command-line flags below;
 * presumably parse_ini_file() may also assign some of them — TODO confirm */
char *key_prefix;          /* -x: prefix of every generated key (may be NULL) */
int num_connections;       /* -c: parsed, but not otherwise used in this file */
int num_requests;          /* -r: parsed, but not otherwise used in this file */
int num_keys;              /* -k: number of key/value pairs generated and SET */
int k_length;              /* -l: total generated key length, prefix included */
int v_length;              /* -v: generated value length */
//int poll_wait_sec;
int rec_interval;          /* -i: seconds between GETK throughput reports */
int rec_round;             /* -n: number of reports before the timer stops */
unsigned long seed;        /* -s: seed handed to initstate() */
char *config_file = NULL;  /* -f: config file for parse_ini_file(); main() requires it */

/* State of the periodic GETK throughput-report timer (armed by setup_timer(),
 * driven by timer_cb()). */
typedef struct ev_timer_st {
	struct event *ev;    /* persistent libevent timer created in setup_timer() */
	struct timeval to;   /* report period; tv_sec comes from rec_interval */
	int timer_count;     /* reports printed so far */
	bool is_running;     /* set when the SET phase completes, cleared after rec_round reports */
	long last_finished;  /* num_get_resp_success at the previous report */
} ev_timer;
ev_timer timer;

/* Event loop created in main(); process_response() also arms the report timer on it. */
struct event_base *base;

/* Request/response counters plus the current benchmark phase. */
typedef struct count_stats_st {
	long num_set_req;           /* SET requests sent */
	long num_set_resp_success;  /* SET responses with status 0 */
	long num_set_resp_fail;     /* SET responses with a non-zero status */
	long num_get_req;           /* GETK requests sent */
	long num_get_resp_success;  /* GETK responses whose value matched the local copy */
	long num_get_resp_fail;     /* GETK responses counted as failures (e.g. value mismatch) */
	uint8_t cmd_mode;           /* phase: CMD_SET first, then CMD_GETK, CMD_NOOP when finished */
} count_stats;
count_stats stats;

/* One generated key/value pair; both fields are filled by gen_random()
 * from init_k_v_pair_list(). */
typedef struct k_v_pair_st {
	char key[MAX_KEY_LEN];
	char value[MAX_VAL_LEN];
	//size_t key_len;
	//size_t value_len;
} k_v_pair;

/* Array of num_keys heap-allocated pairs, built by init_k_v_pair_list()
 * and released by release_k_v_pair_list(). */
k_v_pair **k_v_pair_list;

/* Alphabets gen_random() draws key and value characters from. */
const char key_text[] = "123456789";
const char value_text[] = "abcdefghijklmnopqrstuvwxyz";

/* libevent callbacks and timer setup, defined later in this file. */
void socket_ev_cb(struct bufferevent *bev, short events, void *arg);
void client_recv_cb(struct bufferevent *bev, void *arg);
void setup_timer(struct event_base *base, int sec);

/* Argument registered with client_recv_cb(): scratch buffers that main()
 * allocates with MAX_SET_BUF_SIZE bytes each. */
typedef struct cli_r_cb_arg_st {
	char *req_buf;   /* outgoing request built by process_request() */
	char *resp_buf;  /* incoming response bytes read from the connection */
} cli_r_cb_arg;

/*
 * Worker-thread entry point (pthread start-routine signature).
 * Currently a stub that never returns. It blocks in pause() instead of
 * spinning in an empty loop, so an idle worker does not burn a CPU core.
 */
void *thread_main(void *arg) {
	(void)arg;
	for (;;)
		pause();
}

/*
 * Print the request/response counters accumulated in the global `stats`
 * to stderr: one line with the SET and GET request/success/failure counts.
 */
void display_stats() {
	fprintf(stderr, "==== stats ====\n");
	/* The last column is the GET failure count (it previously repeated the SET one). */
	fprintf(stderr, "set(req=%ld suc=%ld fail=%ld) get(req=%ld suc=%ld fail=%ld)\n",
			stats.num_set_req, stats.num_set_resp_success, stats.num_set_resp_fail,
			stats.num_get_req, stats.num_get_resp_success, stats.num_get_resp_fail);
}

/*
 * Load the compile-time defaults for the tunables that have one
 * (seed, key count, request count, connection count). All other settings
 * keep their current values until the command line / config file sets them.
 */
void init_default_env() {
	seed = DEFAULT_SEED;
	num_keys = DEFAULT_KEYS;
	num_requests = DEFAULT_REQUESTS;
	num_connections = DEFAULT_CONNECTIONS;
}

/*
 * Print command-line help to stdout and terminate with exit status 1.
 * The option list mirrors what main() actually parses; -g and -t were
 * advertised before but never accepted.
 */
void usage(char *prog)
{
	printf("Usage: %s -f file [-a n] [-c n] [-i n] [-k n] [-l n] [-n n] [-r n] [-s s] [-u n] [-v n] [-x s] host port\n", prog);
	printf("\t-f file Config file to load (required)\n");
	printf("\t-a n Use the ASCII protocol if n is non-zero\n");
	printf("\t-c n Open n connections (default: %d)\n",
		DEFAULT_CONNECTIONS);
	printf("\t-i n Report GET throughput every n seconds\n");
	printf("\t-k n Number of keys in key space (default: %d)\n",
		DEFAULT_KEYS);
	printf("\t-l n Key length in bytes, prefix included\n");
	printf("\t-n n Number of throughput reports before stopping\n");
	printf("\t-r n Requests to send simultaneously (default: %d)\n",
		DEFAULT_REQUESTS);
	printf("\t-s s Seed to use for random number generation (default: %d)\n",
		DEFAULT_SEED);
	printf("\t-u n Use UDP if n is non-zero (not supported yet)\n");
	printf("\t-v n Value length in bytes\n");
	printf("\t-x s Prefix for every generated key\n");
	exit(1);
}

/*
 * Returns a strictly positive integer parsed from a command-line argument.
 * Calls usage() (which exits) unless arg is a complete decimal number in
 * 1..INT_MAX. atoi() silently accepted trailing junk ("80x"), negative
 * values and out-of-range input.
 */
int intarg(char *prog, char *arg)
{
	char *end = NULL;
	long x;

	if (NULL == arg)
		usage(prog);

	errno = 0;
	x = strtol(arg, &end, 10);
	if (end == arg || *end != '\0' || errno == ERANGE || x <= 0 || x > INT_MAX)
		usage(prog);
	return (int)x;
}

/*
 * Fill dst with exactly `len` bytes: the optional prefix first, then
 * characters picked with rand() from the NUL-terminated alphabet `src`.
 *
 * dst is NOT NUL-terminated here. Callers that need a C string must supply
 * a zeroed buffer of at least len + 1 bytes.
 * A prefix longer than len is truncated to len bytes. len <= 0 writes
 * nothing, and an empty alphabet leaves the bytes after the prefix untouched.
 */
void gen_random(const char *src, char *dst, int len, char *prefix) {
	if (len <= 0)
		return;

	size_t total = (size_t)len;
	size_t prefix_len = 0;
	if (prefix != NULL) {
		prefix_len = strlen(prefix);
		/* Never write past len bytes (and never underflow the remainder). */
		if (prefix_len > total)
			prefix_len = total;
		memcpy(dst, prefix, prefix_len);
	}

	/* strlen, not sizeof: src is a pointer, so sizeof(src) is the pointer
	 * size and only the first sizeof(char *) - 1 characters were ever used. */
	size_t alphabet_len = strlen(src);
	if (alphabet_len == 0)
		return;

	for (size_t i = prefix_len; i < total; i++)
		dst[i] = src[(size_t)rand() % alphabet_len];
}

void init_k_v_pair_list(int num_keys, char *prefix) {
	k_v_pair_list = (k_v_pair **)malloc(num_keys * sizeof(k_v_pair *));
	for (int i=0; i<num_keys; i++) {
		k_v_pair_list[i] = (k_v_pair *)malloc(sizeof(k_v_pair));
		/*
		//pre-fill
		memset(k_v_pair_list[i]->key, 'a', k_length);
		memset(k_v_pair_list[i]->value, 'b', v_length);

		size_t kl = snprintf(NULL, 0, "%s%d", prefix, i);
		sprintf(k_v_pair_list[i]->key, "%s%d", prefix, i);
		sprintf(k_v_pair_list[i]->key + kl, "testkey");

		size_t vl = snprintf(NULL, 0, "value%d", i);
		sprintf(k_v_pair_list[i]->value, "test%dvalue%d", i, i);
		 */

		gen_random(key_text, k_v_pair_list[i]->key, k_length, prefix);
		gen_random(value_text, k_v_pair_list[i]->value, v_length, NULL);
		fprintf(stderr, "init [%d] k=%s v=%s\n", i, k_v_pair_list[i]->key,
				k_v_pair_list[i]->value);
	}
	fprintf(stderr, "finish init k_v pair list k_length=%d v_length=%d pairs=%d\n",
			k_length, v_length, num_keys);
}

/*
 * Free the first num_keys entries of k_v_pair_list and the list itself,
 * then reset the global to NULL.
 *
 * Calling it when the list was never built (or was already released) is a
 * no-op. Previously that case terminated the process with EXIT_FAILURE and
 * left a dangling pointer behind after a successful release.
 */
void release_k_v_pair_list(int num_keys) {
	if (NULL == k_v_pair_list)
		return;

	for (int i = 0; i < num_keys; i++)
		free(k_v_pair_list[i]);   /* free(NULL) is a no-op */
	free(k_v_pair_list);
	k_v_pair_list = NULL;
}

char *get_value_by_key(char *key, int k_len) {
	for (int i=0; i<num_keys; i++) {
		if (0 == strncmp(key, k_v_pair_list[i]->key, k_len)) {
			return k_v_pair_list[i]->value;
		}
	}
	return NULL;
}

/*
 * Debug helper: write the first len bytes of buf to stderr, four per line,
 * between "< dump start" / "dump end >" markers.
 * mode == 1 prints each byte as two hex digits ("0x%02x"); any other mode
 * prints the raw characters.
 */
void dump_buffer_content(char *buf, int len, int mode) {
	fprintf(stderr, "\n< dump start  ");
	for (int i = 0; i < len; i++) {
		if (0 == i % 4)
			fprintf(stderr, "\n dump > ");
		if (1 == mode) { /* hex mode */
			/* Go through unsigned char so bytes >= 0x80 print as 0x80..0xff
			 * rather than sign-extended 0xffffff80.. when char is signed. */
			fprintf(stderr, "0x%02x ", (unsigned int)(unsigned char)buf[i]);
		} else { /* char mode */
			fprintf(stderr, "%c ", buf[i]);
		}
	}
	fprintf(stderr, "  dump end >\n");
}

/*
 * Parse every memcached binary-protocol response in buf[0..len) and update
 * the global `stats`. Each response is a 24-byte header followed by
 * body_len bytes (extras, key, value).
 *
 * GETK: the returned value is compared with the locally generated one.
 *       When the timer has stopped, or every stored key has been read back
 *       twice on average, the run ends: cmd_mode becomes CMD_NOOP and bev is
 *       freed. This function then returns immediately and never touches bev
 *       again.
 * SET:  counts successes/failures. Once a response has arrived for every
 *       key, it switches to the GETK phase and starts the report timer.
 *
 * Bytes that do not form a complete response are reported and ignored,
 * never read past `len`.
 */
void process_response(struct bufferevent *bev, char *buf, int len) {
	int parsed_bytes = 0;

	while (len - parsed_bytes >= 24) {
		binary_header_t *h = (binary_header_t *)(buf + parsed_bytes);
		uint16_t k_len = ntohs(h->key_len);
		uint32_t body_len = ntohl(h->body_len);
		uint8_t ext_len = h->extra_len;
		int v_len = (int)(body_len - k_len - ext_len);
		int op_type = h->opcode;
		char *body = buf + parsed_bytes + 24;

		/* Never read beyond the bytes actually received. */
		if (body_len > (uint32_t)(len - parsed_bytes - 24)) {
			fprintf(stderr, "truncated response: body_len=%lu but only %d bytes left\n",
					(unsigned long)body_len, len - parsed_bytes - 24);
			return;
		}

		if (CMD_GETK == op_type) {
			if (0 != ntohs(h->status)) {
				/* The server reported an error status: no value to verify. */
				stats.num_get_resp_fail++;
			} else {
				assert(ext_len == 4 && v_len > 0);
				char *exp_get_value = get_value_by_key(body + ext_len, k_len);
				assert(exp_get_value != NULL);
				if (0 == strncmp(body + ext_len + k_len, exp_get_value, (size_t)v_len))
					stats.num_get_resp_success++;
				else
					stats.num_get_resp_fail++;
			}
			parsed_bytes += 24 + body_len;

			if (!timer.is_running || stats.num_get_resp_success == 2*stats.num_set_resp_success) {
				stats.cmd_mode = CMD_NOOP;
				bufferevent_free(bev);
				/* bev is gone: stop here so it can never be freed twice. */
				return;
			}

		} else if (CMD_SET == op_type) {
			assert(ext_len == 0);
			if (0 == ntohs(h->status))
				stats.num_set_resp_success++;
			else
				stats.num_set_resp_fail++;
			/* Skip any body too (empty on success), so a non-empty error
			 * body is not misparsed as the next header. */
			parsed_bytes += 24 + body_len;

			/* Move on once every SET has been answered; waiting for
			 * num_keys successes would stall forever after any failure. */
			if (stats.num_set_resp_success + stats.num_set_resp_fail == num_keys) {
				stats.cmd_mode = CMD_GETK;
				timer.is_running = true;
				setup_timer(base, rec_interval);
			}

		} else {
			fprintf(stderr, "get unexpected response\n");
			/* Skip it; not advancing here made this loop spin forever. */
			parsed_bytes += 24 + body_len;
		}
	}

	if (parsed_bytes < len)
		fprintf(stderr, "ignoring %d trailing bytes of an incomplete header\n",
				len - parsed_bytes);
}

/*
 * Build the next request for the current phase into req_buf and return its
 * length in bytes (0 when there is nothing to send):
 *   CMD_SET : SET the next not-yet-sent key, in list order;
 *   CMD_GETK: GETK a key chosen with rand();
 *   otherwise (e.g. CMD_NOOP): nothing.
 * Updates stats.num_set_req / stats.num_get_req accordingly.
 */
int process_request(char *req_buf) {
	int req_body_len = 0;

	/* No keys: nothing to SET, and rand() % 0 would be undefined. */
	if (num_keys <= 0)
		return 0;

	if (stats.cmd_mode == CMD_SET) {
		/* Every key has already been sent (e.g. the GETK phase was not
		 * triggered yet): stop instead of indexing past the list. */
		if (stats.num_set_req >= num_keys)
			return 0;
		int key_idx = (int) stats.num_set_req;
		compose_binary_set(req_buf, k_v_pair_list[key_idx]->key,
						   k_v_pair_list[key_idx]->value, &req_body_len);
		/* NOTE(review): presumably compose_binary_set() reports only the
		 * key+value length and 32 = 24-byte header + 8 bytes of SET extras
		 * — confirm against binary_protocol.h. */
		req_body_len += 32;
		stats.num_set_req++;

	} else if (stats.cmd_mode == CMD_GETK) {
		int key_idx = rand() % num_keys;
		req_body_len = compose_binary_get(req_buf, k_v_pair_list[key_idx]->key, CMD_GETK, 0, 0);
		stats.num_get_req++;
	}
	/* Any other mode sends nothing. */
	return req_body_len;
}

/*
 * Persistent report-timer callback (the fd/event/arg parameters are unused).
 * For the first rec_round firings it prints the GETK throughput of the
 * elapsed period (successful GETKs since the last report divided by the
 * period in seconds). On the firing after that, it marks the timer as
 * stopped, removes the event from the loop and prints the final counters.
 */
void timer_cb(int sock, short which, void *arg) {
	(void)sock;
	(void)which;
	(void)arg;

	if (timer.timer_count >= rec_round) {
		timer.is_running = false;
		event_del(timer.ev);
		fprintf(stderr, "exit timer_cb\n");
		display_stats();
		return;
	}

	timer.timer_count++;
	long done_this_round = stats.num_get_resp_success - timer.last_finished;
	double tps = 1.0 * done_this_round / timer.to.tv_sec;
	fprintf(stderr, "# round=%d get_op tps=%.1f\n", timer.timer_count, tps);
	timer.last_finished = stats.num_get_resp_success;
}

void setup_timer(struct event_base *base, int sec) {
	timer.to.tv_sec = sec;
	timer.to.tv_usec = 0;
	timer.last_finished = 0;
	timer.ev = event_new(base, -1, EV_PERSIST, timer_cb, NULL);
	evtimer_add(timer.ev, &timer.to);
}

/*
 * Entry point.
 *  1. Parse the options; -f <config> is required and is parsed by
 *     parse_ini_file().
 *  2. Connect to host:port over TCP. UDP is not implemented and exits.
 *  3. Generate num_keys random key/value pairs.
 *  4. Drive the benchmark from a libevent loop: SET every key, then send
 *     random GETKs while the report timer runs. display_stats() prints the
 *     counters when the timer finishes.
 *  5. Release everything once the loop has no more events.
 */
int main(int argc, char **argv) {
	char *host;
	int port;
	char random_state[16];
	int c;

	init_default_env();

	/* Every option handled below must appear here; 's' was missing, which
	 * made -s unreachable (getopt returned '?' and usage() exited). */
	while ((c = getopt(argc, argv, "c:k:r:s:u:x:a:f:v:l:i:n:")) != EOF) {
		switch (c) {
			case 'a':
				use_ascii_protocol = atoi(optarg) ? true : false;
				break;
			case 'c':
				num_connections = intarg(argv[0], optarg);
				break;
			case 'k':
				num_keys = intarg(argv[0], optarg);
				break;
			case 'v':
				v_length = intarg(argv[0], optarg);
				break;
			case 'l':
				k_length = intarg(argv[0], optarg);
				break;
			case 'r':
				num_requests = intarg(argv[0], optarg);
				break;
			case 's':
				seed = intarg(argv[0], optarg);
				break;
			case 'i':
				rec_interval = intarg(argv[0], optarg);
				break;
			case 'n':
				rec_round = intarg(argv[0], optarg);
				break;
			case 'u':
				enable_udp = atoi(optarg) ? true : false;
				break;
			case 'x':
				key_prefix = strdup(optarg);
				break;
			case 'f':
				config_file = strdup(optarg);
				break;
			default:
				usage(argv[0]);
				break;
		}
	}

	initstate(seed, random_state, sizeof(random_state));
	if (NULL == config_file) {
		/* Previously this exited silently. */
		fprintf(stderr, "%s: a config file must be given with -f\n", argv[0]);
		usage(argv[0]);
	}
	parse_ini_file(config_file);

	/* at least has host, port parameters */
	if (argc < optind + 2)
		usage(argv[0]);

	host = argv[optind++];
	port = intarg(argv[0], argv[optind]);

	if (enable_udp) {
		/* UDP transport is not implemented yet. */
		exit(EXIT_FAILURE);
	}
	tcp_fd = clientsock(host, port);
	int sfd = tcp_fd;

	init_k_v_pair_list(num_keys, key_prefix);

	base = event_base_new();
	if (NULL == base) {
		fprintf(stderr, "event_base_new failed\n");
		exit(EXIT_FAILURE);
	}

	/* Both buffers are MAX_SET_BUF_SIZE bytes; client_recv_cb relies on that. */
	cli_r_cb_arg r_cb_arg;
	r_cb_arg.req_buf = malloc(MAX_SET_BUF_SIZE);
	r_cb_arg.resp_buf = malloc(MAX_SET_BUF_SIZE);
	if (NULL == r_cb_arg.req_buf || NULL == r_cb_arg.resp_buf) {
		fprintf(stderr, "failed to allocate request/response buffers\n");
		exit(EXIT_FAILURE);
	}

	srand(time(NULL));
	evutil_make_socket_nonblocking(sfd);
	struct bufferevent *bev = bufferevent_socket_new(base, sfd, BEV_OPT_CLOSE_ON_FREE);
	if (NULL == bev) {
		fprintf(stderr, "bufferevent_socket_new failed\n");
		exit(EXIT_FAILURE);
	}
	bufferevent_setcb(bev, client_recv_cb, NULL, socket_ev_cb, (void *)&r_cb_arg);
	bufferevent_enable(bev, EV_READ|EV_PERSIST);

	stats.cmd_mode = CMD_SET; //init set first
	// ignite the first set_request; each response triggers the next request
	int req_len = process_request(r_cb_arg.req_buf);
	int ret = bufferevent_write(bev, r_cb_arg.req_buf, (size_t) req_len);
	assert(ret == 0);

	event_base_dispatch(base);
	fprintf(stderr, "finish dispatch\n");

	/* timer_cb only event_del()s the report timer; free it before its base. */
	if (NULL != timer.ev) {
		event_free(timer.ev);
		timer.ev = NULL;
	}
	event_base_free(base);

	free(r_cb_arg.req_buf);
	free(r_cb_arg.resp_buf);

	release_k_v_pair_list(num_keys);

	fprintf(stderr, "finish free base\n");

	return 0;
}

/*
 * Event callback for the benchmark connection. It logs why it fired (EOF,
 * error or timeout) and frees the bufferevent. main() creates it with
 * BEV_OPT_CLOSE_ON_FREE, so freeing it also closes the socket.
 */
void socket_ev_cb(struct bufferevent *bev, short events, void *arg) {
	evutil_socket_t fd = bufferevent_getfd(bev);
	(void)arg;
	/* evutil_socket_t is a signed int on POSIX (-1 when unset); %u misprinted it. */
	fprintf(stderr, "socket_event_cb fd=%d\n", (int)fd);
	if (events & BEV_EVENT_EOF) {
		fprintf(stderr, "connection closed\n");
	} else if (events & BEV_EVENT_ERROR) {
		fprintf(stderr, "some other error\n");
	} else if (events & BEV_EVENT_TIMEOUT) {
		fprintf(stderr, "time out error\n");
	}
	bufferevent_free(bev);
}

void client_recv_cb(struct bufferevent *bev, void *arg) {
	cli_r_cb_arg *cb_arg = (cli_r_cb_arg *)arg;
	ssize_t len = bufferevent_read(bev, cb_arg->resp_buf, 1024);
	assert(len >=24);
	cb_arg->resp_buf[len] = '\0';
	process_response(bev, cb_arg->resp_buf, len);
	//while (CMD_GETK == stats.cmd_mode && stats.num_get_req < stats.num_get_resp_success+5) {
		int req_len = process_request(cb_arg->req_buf);
		int ret = bufferevent_write(bev, cb_arg->req_buf, (size_t) req_len);
		assert(ret == 0);
	//}

	//sleep(1);
}
